

import pandas as pd
import numpy as np

# Column headers of the three membership-inference-attack (MIA) scores that
# the rest of the script reads from the compiled results table.
mia_columns = [
    "Forget vs Retain Membership Inference Attack (MIA)",
    "Forget vs Test Membership Inference Attack (MIA)",
    "Test vs Retain Membership Inference Attack (MIA)",
]

# Load the compiled results table; everything below works from `df`.
csv_path = r"C:\Temp\Unlearning\Data Appendix\Cifar 10 Resnet Saliency\compiled_results.csv"
df = pd.read_csv(filepath_or_buffer=csv_path)

def compute_fi(B, R, M):
    """Return how much of the gap between B and R is closed by M.

    The result is ``(|B - R| - |M - R|) / |B - R|``: 1 when ``M`` equals
    ``R``, 0 when ``M`` is as far from ``R`` as ``B`` is, and negative when
    ``M`` is even farther from ``R`` than ``B``.

    Args:
        B: Reference value (the script passes the baseline run's score).
        R: Target value (the script passes the retrain run's score).
        M: Value being scored.

    Returns:
        The ratio described above, or the integer 0 when ``B`` equals ``R``
        and the ratio is therefore undefined.
    """
    gap = abs(B - R)
    # Guard against dividing by zero when the two references coincide.
    if gap == 0:
        return 0
    return (gap - abs(M - R)) / gap

def compute_musi(f_i, alpha=13.8):
    """Map ``f_i`` onto a 0-100 scale with a logistic curve centred at 0.5.

    Args:
        f_i: Value to transform; a scalar or a NumPy array.
        alpha: Steepness of the logistic curve.

    Returns:
        ``100 / (1 + exp(-alpha * (f_i - 0.5)))`` evaluated in that grouping,
        i.e. exactly 50 at ``f_i == 0.5``, approaching 100 above it and 0
        below it.
    """
    exponent = -alpha * (f_i - 0.5)
    logistic = 1 / (1 + np.exp(exponent))
    return 100 * logistic

# Score every row of the results table against the baseline and retrain runs
# that share its (seed, dataset, model) key. For each MIA column this adds an
# "f_i (...)" and a "MUS_i (...)" column, then "MIAU" as the row-wise mean of
# the MUS_i columns. The augmented table is written to CSV and previewed.
key_cols = ['seed', 'dataset', 'model']

baseline_df = df[df['unlearning'] == 'baseline']
retrain_df = df[df['unlearning'] == 'retrain']

final_df = df.copy()


def _reference_scores(ref_df, suffix):
    """Return one row of MIA scores per (seed, dataset, model) key.

    Args:
        ref_df: Rows of the results table to use as reference values.
        suffix: Text appended to each MIA column name so that baseline and
            retrain values can sit side by side after merging.

    Returns:
        A frame with the key columns plus the renamed MIA columns, holding
        the first row found for each key.
    """
    # Rows with a missing key are dropped so they can never be matched:
    # a merge would otherwise pair NaN keys with each other.
    ref = ref_df[key_cols + mia_columns].dropna(subset=key_cols)
    # When several reference rows share a key, the first one wins.
    ref = ref.drop_duplicates(subset=key_cols, keep='first')
    return ref.rename(columns={name: f"{name}{suffix}" for name in mia_columns})


# Resolve the reference scores for all rows with two keyed joins instead of
# re-filtering the reference frames for every row and every column. A left
# join against unique keys yields exactly one output row per input row, in
# the same order, so the result is positionally aligned with final_df.
reference = (
    final_df[key_cols]
    .merge(_reference_scores(baseline_df, " [B]"), on=key_cols, how='left')
    .merge(_reference_scores(retrain_df, " [R]"), on=key_cols, how='left')
)

for col in mia_columns:
    f_col = f"f_i ({col})"
    mus_col = f"MUS_i ({col})"

    triples = zip(reference[f"{col} [B]"], reference[f"{col} [R]"], final_df[col])
    # A row is left unscored (NaN) when its baseline value, retrain value or
    # own value is missing.
    f_values = np.array(
        [
            np.nan if pd.isna(B) or pd.isna(R) or pd.isna(M) else compute_fi(B, R, M)
            for B, R, M in triples
        ],
        dtype=float,
    )

    final_df[f_col] = f_values
    # compute_musi is element-wise on arrays; NaN inputs stay NaN.
    final_df[mus_col] = compute_musi(f_values)

mus_cols = [f"MUS_i ({col})" for col in mia_columns]
# DataFrame.mean skips NaN by default, so a row with some unscored columns is
# averaged over the scores it does have.
final_df["MIAU"] = final_df[mus_cols].mean(axis=1)

output_path = r"C:\Temp\Unlearning\Data Appendix\Cifar 10 Resnet Saliency\compiled_results_MIAU.csv"
final_df.to_csv(output_path, index=False)

print(final_df[["Filename", "unlearning", "seed", "MIAU"] + mus_cols].head())



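# NOTE: Disabled alternative implementation, kept for reference only. It scores
# only the unlearning-method rows (baseline and retrain rows are excluded),
# joins the reference scores with inner merges, and calls compute_musi with
# alpha=5 instead of that function's default.
# method_df = df[~df['unlearning'].isin(['baseline', 'retrain'])].copy()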
# for col in mia_columns:
#     f_col = f"f_i ({col})"
#     mus_col = f"MUS_i ({col})"

#     merged = method_df.merge(
#         baseline_df[['seed', 'dataset', 'model', col]].rename(columns={col: 'B'}),
#         on=['seed', 'dataset', 'model']
#     ).merge(
#         retrain_df[['seed', 'dataset', 'model', col]].rename(columns={col: 'R'}),
#         on=['seed', 'dataset', 'model']
#     )

#     merged[f_col] = merged.apply(lambda row: compute_fi(row['B'], row['R'], row[col]), axis=1)
#     merged[mus_col] = merged[f_col].apply(lambda f: compute_musi(f, alpha=5))

#     final_df = final_df.merge(merged[['Filename', f_col, mus_col]], on='Filename', how='left')